import torch.nn as nn
from transformers import BertModel
# Model
class BertForIntentClassification(nn.Module):
    """Joint NLU model on top of a pretrained BERT encoder.

    Produces three sets of logits from one encoder pass:
      * domain classification  (sequence-level, from the pooled output)
      * intent classification  (sequence-level, from the pooled output)
      * slot filling           (token-level, from the per-token hidden states)

    Expected ``config`` attributes: ``pretrained_model_path``, ``dropout``,
    ``hidden_size``, ``num_domain_labels``, ``num_intent_labels``,
    ``num_slot_labels``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config  # task/model hyper-parameters
        # Load the pretrained BERT encoder from disk.
        self.bert = BertModel.from_pretrained(config.pretrained_model_path)
        self.bert_config = self.bert.config
        # Sequence-level heads (fed with the pooled [CLS] representation)
        # and the token-level slot head share the same Dropout -> Linear shape.
        self.domain_classification = self._head(config.num_domain_labels)
        self.intent_classification = self._head(config.num_intent_labels)
        self.slot_classification = self._head(config.num_slot_labels)

    def _head(self, num_labels):
        """Build a Dropout -> Linear classification head with *num_labels* outputs."""
        return nn.Sequential(
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.hidden_size, num_labels),
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Run the encoder once and return (domain, intent, slot) logits.

        Returns:
            domain_output: (batch_size, num_domain_labels)
            intent_output: (batch_size, num_intent_labels)
            slot_output:   (batch_size, seq_len, num_slot_labels)
        """
        bert_output = self.bert(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        # [0] = last hidden states (per token), [1] = pooled [CLS] output.
        sequence_output, pooled_output = bert_output[0], bert_output[1]
        domain_output = self.domain_classification(pooled_output)
        intent_output = self.intent_classification(pooled_output)
        slot_output = self.slot_classification(sequence_output)
        return domain_output, intent_output, slot_output